"""
@content: IQL
@author: 不去幼儿园
@Timeline: 2024.08.21
"""
import random
from torch import nn
import torch
import numpy as np
import collections
import matplotlib.pyplot as plt
from env_base import Env  # custom environment
import argparse
from scipy.signal import savgol_filter
import time
import os
from tensorboardX import SummaryWriter
writer = SummaryWriter()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)
 
 
# Experience replay buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)
 
    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
 
    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)  # list of length batch_size
        # zip(*transitions) unpacks the sampled tuples into per-field sequences
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done
 
    def size(self):
        return len(self.buffer)
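    # Minimal usage sketch (hedged; assumes 4-dim observation vectors):
    #   buf = ReplayBuffer(capacity=100)
    #   buf.add(np.zeros(4), 0, 1.0, np.ones(4), False)
    #   s, a, r, ns, d = buf.sample(1)  # s.shape == (1, 4); a, r, d are tuples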
 
 
# Q-network
class DqnNet(nn.Module):
    def __init__(self, n_states, n_actions, n_hidden=128):
        super(DqnNet, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(n_states, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_hidden // 2),
            nn.ReLU(),
            nn.Linear(n_hidden // 2, n_actions)
        )
 
    def forward(self, state):
        return self.model(state)
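    # The forward pass maps a state vector to one Q-value per discrete action,
    # so the greedy policy is an argmax over the output. Hedged sketch
    # (assuming 4-dim states and 5 actions):
    #   net = DqnNet(n_states=4, n_actions=5)
    #   q = net(torch.zeros(4))  # q.shape == torch.Size([5])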
 
 
# Agent: one independent DQN learner per UAV
class Agent(object):
    def __init__(self, identifier, n_states, n_hidden, n_actions, learning_rate, gamma, epsilon, target_update, device=None):
        # Agent parameters
        self.identifier = identifier
        self.n_states = n_states
        self.n_hidden = n_hidden
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        self.target_update = target_update
        self.device = device
 
        # Agent networks: online Q-network and target Q-network
        self.Q_net = DqnNet(self.n_states, self.n_actions, self.n_hidden).to(self.device)
        self.Q_target_net = DqnNet(self.n_states, self.n_actions, self.n_hidden).to(self.device)
        self.optimizer = torch.optim.Adam(self.Q_net.parameters(), lr=self.learning_rate)

        self.count = 0  # gradient-step counter for target-network sync
 
    def take_action(self, state):
        # Epsilon-greedy action selection (see the note below on the convention)
        if np.random.uniform(0, 1) < self.epsilon:
            state = torch.tensor(state, dtype=torch.float, device=self.device)
            with torch.no_grad():
                action = torch.argmax(self.Q_net(state)).item()
        else:
            action = np.random.randint(0, self.n_actions)

        return int(action)
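    # Note: `epsilon` here is the probability of acting greedily (exploiting),
    # the reverse of the common epsilon-greedy convention; with epsilon=0.99
    # the agent explores about 1% of the time, and no decay schedule is used.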
 
    def update(self, transition_dict):
        states = torch.tensor(transition_dict["states"], dtype=torch.float, device=self.device)
        actions = torch.tensor(transition_dict["actions"]).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float, device=self.device).view(-1, 1)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float, device=self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float, device=self.device).view(-1, 1)
 
        predict_q_values = self.Q_net(states).gather(1, actions)  # [b,1]
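        # TD target: y = r + gamma * max_a' Q_target(s', a') * (1 - done).
        # It is computed from the frozen target network under no_grad, so
        # gradients flow only through the online network's predictions.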
        with torch.no_grad():
            max_next_q_values = self.Q_target_net(next_states).max(1)[0].view(-1, 1)
            q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)
 
        dqn_loss = nn.MSELoss()(predict_q_values, q_targets)
        self.optimizer.zero_grad()
        dqn_loss.backward()
        self.optimizer.step()
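        # Periodically hard-copy the online weights into the target network.
        # Note: with target_update=1 (as configured below) this copies after
        # every gradient step, so the target network provides no real lag.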
 
        if self.count % self.target_update == 0:
            self.Q_target_net.load_state_dict(self.Q_net.state_dict())
        self.count += 1
 
 
# Main training loop
def run(env, agent_list, replay_buffer_list, batch_size):
    state = env.reset()
    reward_total = [0. for _ in range(env.uav_num)]
    done = False
 
    while not done:
        """训练主流程代码"""
        action_list = []
        for i, agent in enumerate(agent_list):  # 每个智能体单独决策
            action = agent.take_action(state[i])  # 获取无人机动作，@移植注意更改
            action_list.append(action)
        """环境步进更新，返回下一个状态值、奖励值、环境结束状态等
           @移植注意：返回值格式问题
        """
        next_state, reward, done_, eval_infos = env.step(action_list)
        reward = list(np.array(reward).flatten())
        done = done_[0]
        # print("======action====", action_list)
        # Store transitions and, once enough data is buffered, update each agent's network
        for j, replay_buffer in enumerate(replay_buffer_list):
            replay_buffer.add(state[j], action_list[j], reward[j], next_state[j], done)  # per-agent buffer
            if replay_buffer.size() > batch_size:
                s, a, r, ns, d = replay_buffer.sample(batch_size)  # sample a mini-batch
                transition_dict = {
                    'states': s,
                    'actions': a,
                    'next_states': ns,
                    'rewards': r,
                    'dones': d,
                }
                agent_list[j].update(transition_dict)  # each agent updates its own network
 
        state = next_state  # advance the state
        reward_total = [x + y for x, y in zip(reward_total, reward)]  # accumulate per-agent rewards
 
    return reward_total
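# Algorithm note: this is Independent Q-Learning (IQL) -- each UAV runs plain
# DQN on its own observation/reward stream and treats the other agents as part
# of a (non-stationary) environment; there is no centralized critic and no
# parameter sharing between agents.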
 
 
# Main evaluation loop
def run_test(env, agent_list, show_flag, total_num):
    state = env.reset()
    reward_episode_eva = [0. for _ in range(env.uav_num)]
    done = False
    while not done:
        """测试主流程代码"""
        action_list = []
        for i, agent in enumerate(agent_list):  # 每个智能体单独决策
            action = agent.take_action(state[i])  # 获取无人机动作，@移植注意更改
            action_list.append(action)
        """环境步进更新，返回下一个状态值、奖励值、环境结束状态等
           @移植注意：返回值格式问题
        """
        next_state, reward, done_, eval_infos = env.step(action_list)
        reward = list(np.array(reward).flatten())
        done = done_[0]
        reward_episode_eva = [x + y for x, y in zip(reward_episode_eva, reward)]
        state = next_state
 
    goal_num_buffer = np.array(env.get_state_data())
    log_flag = ["state/target_one_num", "state/target_two_num", "state/barrier_crash_num", "state/coverage_ratio"]
    agent_flag = ["agent/agent01", "agent/agent02", "agent/agent03"]
    for i in range(len(goal_num_buffer)):
        log_state(name=log_flag[i], state={log_flag[i]: goal_num_buffer[i]}, step=total_num)

    for i in range(len(agent_flag)):
        log_state(name=agent_flag[i], state={agent_flag[i]: reward_episode_eva[i]}, step=total_num)
 
    reward_total = {"state/reward_total": sum(reward_episode_eva)}
    log_state(name="state/reward_total", state=reward_total, step=total_num)
    return reward_total
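# Evaluation note: run_test reuses the same epsilon-greedy policy as training
# (so with epsilon=0.99 roughly 1% of test actions are still random) and does
# not write transitions to the replay buffers.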
 
 
"""tensorboard结果展示函数"""
def log_state(name, state, step):
    writer.add_scalars(name, state, step)
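# Note: add_scalars creates a separate event-file run per dict key; since
# `state` always holds a single key equal to `name`, plain
# writer.add_scalar(name, value, step) would log the same data more simply.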
 
 
# Plotting helper
def display(return_list, test_reward_list):
    timestamp = time.strftime("%Y%m%d%H%M%S")
    result_path = os.path.dirname(os.path.realpath(__file__)) + '/results/txt/'
    os.makedirs(result_path, exist_ok=True)  # make sure the output directory exists
    plt.figure(2)
    # Apply Savitzky-Golay smoothing
    window_length = 31  # window length; must be odd
    polyorder = 2  # polynomial order
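    # Note: savgol_filter requires the input length to be at least
    # window_length (31 here); for short runs, shrink the window or skip smoothing.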
    smoothed_data_1 = savgol_filter(np.array(return_list), window_length, polyorder)
    plt.plot(smoothed_data_1, label='Smoothed Data', alpha=1)
    plt.plot(list(range(len(np.array(return_list)))), np.array(return_list), label='Original Data', alpha=0.2)
    plt.xlabel('Episodes')
    plt.ylabel('Reward')
    plt.legend()
    plt.title('IQL Train on {}'.format(env_name))
    plt.savefig(result_path + f"train_reward_{timestamp}.png", format='png')
    np.savetxt(result_path + f'epi_reward_{timestamp}.txt', return_list)
 
    plt.figure(3)
    smoothed_data_2 = savgol_filter(np.array(test_reward_list), window_length, polyorder)
    plt.plot(smoothed_data_2, label='Smoothed Data', alpha=1)
    plt.plot(list(range(len(np.array(test_reward_list)))), np.array(test_reward_list), alpha=0.2, label='Original Data')
    plt.xlabel('Episodes')
    plt.ylabel('Test-Reward')
    # plt.ylim(0, 21000)
    plt.legend()
    plt.title('IQL Test on {} (map_size={})'.format(env_name, env.map_size))
    plt.savefig(result_path + f"test_reward_{timestamp}.png", format='png')
    plt.show()
 
 
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Environment parameter settings
    EnvArgs = parser.parse_args()
    EnvArgs.map_size = 10  # map size of the environment
    EnvArgs.uav_num = 3  # number of UAVs
    EnvArgs.env_step_time = 50  # number of discrete steps per episode
 
    env = Env(args=EnvArgs)  # environment instantiation; @porting note: adapt to your env
    env_name = "Environment"
    n_states, n_actions = env.get_state_space()[0], env.get_action_space()  # @porting note: adapt to your env
 
    capacity = 5000  # replay buffer capacity
    lr = 6e-3  # learning rate
    gamma = 0.98  # discount factor
    epsilon = 0.99  # probability of acting greedily (see note in Agent.take_action)
    target_update = 1  # target-network update period (in gradient steps)
    batch_size = 64  # mini-batch size
    n_hidden = 128  # number of hidden units
    return_list = []  # per-episode training returns
    test_reward_list = []  # per-evaluation returns
    replay_buffer_list = [ReplayBuffer(capacity) for _ in range(env.uav_num)]  # one buffer per agent
    num_episodes = 8000  # number of training episodes
    num_test = num_episodes // 5  # number of evaluation runs
    agent_list = []  # list of agents
    total_num = 0  # total number of environment steps
 
    for i in range(0, EnvArgs.uav_num):  # create one agent per UAV
        new_agent = Agent(identifier=i, n_states=n_states, n_hidden=n_hidden, n_actions=n_actions, learning_rate=lr,
                          gamma=gamma, epsilon=epsilon, target_update=target_update, device=device)
        agent_list.append(new_agent)
 
    for episode in range(num_episodes):  # main training loop
        episode_reward = run(env=env, agent_list=agent_list, replay_buffer_list=replay_buffer_list, batch_size=batch_size)
        return_list.append(sum(episode_reward))
        total_num += EnvArgs.env_step_time  # each episode advances env_step_time environment steps
        if episode % 5 == 0:  # periodic evaluation
            test_reward = run_test(env=env, agent_list=agent_list, show_flag=False, total_num=total_num)
            test_reward_list.append(test_reward["state/reward_total"])
            print(f"Evaluation round {episode // 5 + 1}, training reward of agent 0: {episode_reward[0]}")
    print(f"Total number of environment steps: {total_num}")
    # Plot and save the results
    display(return_list, test_reward_list)